# -*- coding: utf-8 -*-
"""scRNAseq_for_Cre-loxP.ipynb

Automatically generated by Colab.

# Installing packages
"""

!pip install scanpy
!pip install gff3
!pip install umap-learn[plot]
!pip install gffutils

!pip install leidenalg

import pandas as pd
import glob
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import umap
import umap.plot

from gff3 import Gff3
import gffutils
import scanpy as sc

!pip install pyarrow

"""# R3 Pre-processing and filtering expression data"""

import os

# Raw cell x gene UMI table for experiment 1 (with Cre).
pos = pd.read_csv('alexp1CRE_cell_gene.csv')

# Cache the table as Feather (needs pyarrow) and reload it from that file.
pos.to_feather('exp1CRE_data.feather')
pos_feather = pd.read_feather('exp1CRE_data.feather')

# Use the 'gene_id' column as the row labels of the dataframe.
pos_feather.set_index('gene_id', inplace=True)

filename = "Ecoli_mg1655_reporters_wCRE.gff3"
database_filename = "E_coli_MG1655_db_CRE"  # on-disk gffutils database built from the GFF3

# gffutils.create_db does not overwrite an existing database file unless
# force=True, so build it only when it is missing; this keeps the cell re-runnable.
if not os.path.exists(database_filename):
    gffutils.create_db(filename, database_filename, merge_strategy="warning")

# Load the (existing or freshly built) database into the 'db' object.
db = gffutils.FeatureDB(database_filename)

# Every feature whose ID contains 'gene'; reading attributes straight from these
# feature objects avoids a second database lookup per gene.
gene_features = [feat for feat in db.all_features() if 'gene' in feat.id]
genes = [feat.id for feat in gene_features]

# Parallel lists of ID, biotype and display name (Name, or the ID when Name is absent).
gene_ids, gene_biotypes, gene_names = [], [], []
for feat in gene_features:
    attrs = feat.attributes
    fid = attrs['ID'][0]
    kind = attrs['biotype'][0]  # assumed present on every gene feature
    label = attrs['Name'][0] if 'Name' in attrs else fid
    gene_ids.append(fid)
    gene_biotypes.append(kind)
    gene_names.append(label)

genes_df = pd.DataFrame({'ID': gene_ids, 'biotype': gene_biotypes, 'Name': gene_names})

# Split tRNA/rRNA genes (removed from the expression matrix later) from the rest.
is_structural_rna = genes_df['biotype'].isin(['tRNA', 'rRNA'])
excluded_gene_ids = genes_df.loc[is_structural_rna, 'ID'].tolist()
filtered_genes_df = genes_df[~is_structural_rna]

print(filtered_genes_df)
print(excluded_gene_ids)

# Untouched copy of the raw matrix, kept before tRNA/rRNA rows are removed.
pos_feather_raw = pos_feather.copy()

# Drop tRNA/rRNA genes. errors='ignore' tolerates annotated tRNA/rRNA IDs that
# have no row in the count matrix instead of aborting with a KeyError.
pos_feather = pos_feather.drop(labels=excluded_gene_ids, axis=0, errors='ignore')

# Make 'gene_id' a regular column so it can be rewritten.
pos_feather.reset_index(inplace=True)

# ID -> display name for all non-tRNA/non-rRNA genes.
id_to_name_map = filtered_genes_df.set_index('ID')['Name'].to_dict()

# Replace IDs with names; rows whose ID is not in the map keep their original ID
# instead of becoming NaN index labels.
pos_feather['gene_id'] = pos_feather['gene_id'].map(id_to_name_map).fillna(pos_feather['gene_id'])

# Strip a leading 'gene:' prefix where present (names that fell back to the ID).
pos_feather['gene_id'] = pos_feather['gene_id'].str.replace('^gene:', '', regex=True)

pos_feather.set_index('gene_id', inplace=True)

# Cells as rows, genes as columns.
pos = pos_feather.T

# Total UMIs per cell (sum across the gene columns), appended as the last column.
pos['Total UMIs'] = pos.sum(axis=1)

pos['Total UMIs'].sum()  # grand total of UMIs, excluding tRNA/rRNA

# Knee plot (with Cre): total UMIs per barcode, ranked highest to lowest, log-log axes.
read_rank = sorted(pos['Total UMIs'], reverse=True)
n_barcodes = len(read_rank)
x_axis = range(1, n_barcodes + 1)
plt.loglog(x_axis, read_rank)
plt.xlabel('Barcodes')
plt.ylabel('mRNA UMI')
plt.xlim([1, n_barcodes + 1])

# All barcodes: mScarlet UMIs against total mRNA UMIs.
fig, ax = plt.subplots(figsize=(9, 6))
sns.scatterplot(data=pos, x='Total UMIs', y='mScarlet', ax=ax)

ax.minorticks_on()
ax.tick_params(axis='x', which='minor', length=4, width=1, direction='out')
ax.xaxis.set_minor_locator(plt.MultipleLocator(500))  # minor x-ticks every 500 UMIs

ax.set_xlabel('Total mRNA UMI')
ax.set_ylabel('mScarlet UMI')
ax.set_title('Distribution of mScarlet vs total mRNA UMI')
plt.show()

# Keep cells with 200 < total UMIs < 7000 (16575 cells in the recorded run).
total_umis = pos['Total UMIs']
pos_200_7k = pos[total_umis.gt(200) & total_umis.lt(7000)]

# mScarlet-positive cells, counted on all cells in pos (before the UMI filter).
mS_count = pos['mScarlet'].gt(0).sum()
mS_count

# Cells with 200 < total UMIs < 7000: mScarlet UMIs against total mRNA UMIs.
plt.figure(figsize=(7.5, 6))
sns.scatterplot(data=pos_200_7k, x='Total UMIs', y='mScarlet')

# Minor x-ticks at intervals of 200 UMIs.
plt.minorticks_on()
plt.tick_params(axis='x', which='minor', length=4, width=1, direction='out')
plt.gca().xaxis.set_minor_locator(plt.MultipleLocator(200))

plt.xlabel('Total mRNA UMI')
plt.ylabel('mScarlet UMI')
plt.title('Distribution of mScarlet vs total UMI')

plt.show()

# Violin plot of total mRNA UMIs per cell after the 200-7000 UMI filter.
plt.figure()
plt.minorticks_on()
plt.tick_params(axis='y', which='minor', length=4, width=1, direction='out')
plt.gca().yaxis.set_minor_locator(plt.MultipleLocator(250))  # minor y-ticks every 250 UMIs
sns.violinplot(data=pos_200_7k, y='Total UMIs', cut=0)
plt.show()  # finish this figure so the next violin is not drawn onto the same axes

# Number of detected genes per cell (w/Cre): gene columns with a non-zero count.
# The derived 'Total UMIs' column (> 200 for every cell here, so it would add 1
# to each count) is excluded, as is 'gene_count' itself in case this is re-run.
# NOTE(review): the gene-count cutoff of 760 used later was picked when this
# count still included 'Total UMIs' (+1 per cell); re-check that threshold.
gene_columns = pos_200_7k.drop(columns=['Total UMIs', 'gene_count'], errors='ignore')
# .assign returns a new frame, avoiding SettingWithCopyWarning on the filtered slice.
pos_200_7k = pos_200_7k.assign(gene_count=gene_columns.astype(bool).sum(axis=1))

# Violin plot of detected genes per cell.
plt.figure()
plt.minorticks_on()
plt.tick_params(axis='y', which='minor', length=4, width=1, direction='out')
plt.gca().yaxis.set_minor_locator(plt.MultipleLocator(50))  # minor y-ticks every 50 genes
sns.violinplot(data=pos_200_7k, y='gene_count', cut=0)
plt.show()

# Cells with 200 < total UMIs < 7000 (w/Cre): detected gene count against total UMIs.
fig, ax = plt.subplots(figsize=(7.5, 6))
sns.scatterplot(data=pos_200_7k, x='Total UMIs', y='gene_count', ax=ax)

ax.minorticks_on()
for axis_name in ('y', 'x'):
    ax.tick_params(axis=axis_name, which='minor', length=4, width=1, direction='out')
ax.yaxis.set_minor_locator(plt.MultipleLocator(50))   # every 50 genes
ax.xaxis.set_minor_locator(plt.MultipleLocator(200))  # every 200 UMIs

ax.set_xlabel('Total mRNA UMI')
ax.set_ylabel('mRNA Gene count')
ax.set_title('Distribution of Gene count vs total UMI')
plt.show()

# Stricter single cutoffs, each applied on its own to pos_200_7k
# (cell counts as recorded in the original run):
#   total UMIs < 3400 -> 16449 cells
#   gene_count < 760  -> 16532 cells
pos_200_34k = pos_200_7k.loc[pos_200_7k['Total UMIs'].lt(3400)]
pos_gene_760 = pos_200_7k.loc[pos_200_7k['gene_count'].lt(760)]

# mScarlet-positive cells remaining under each cutoff.
mS_count = pos_200_34k['mScarlet'].gt(0).sum()
mS_count

mS_count = pos_gene_760['mScarlet'].gt(0).sum()
mS_count

# Cells passing the total-UMI < 3400 cutoff: mScarlet UMIs against detected gene count.
fig, ax = plt.subplots(figsize=(7.5, 6))
sns.scatterplot(data=pos_200_34k, x='gene_count', y='mScarlet', ax=ax)

ax.minorticks_on()
ax.tick_params(axis='x', which='minor', length=4, width=1, direction='out')
ax.xaxis.set_minor_locator(plt.MultipleLocator(20))  # minor x-ticks every 20 genes

ax.set_xlabel('Gene count')
ax.set_ylabel('mScarlet')
ax.set_title('Distribution of mScarlet vs gene count')
plt.show()

# Cells passing the gene_count < 760 cutoff: mScarlet UMIs against total mRNA UMIs.
plt.figure(figsize=(7.5, 6))
sns.scatterplot(data=pos_gene_760, x='Total UMIs', y='mScarlet')

# Minor x-ticks at intervals of 200 UMIs.
plt.minorticks_on()
plt.tick_params(axis='x', which='minor', length=4, width=1, direction='out')
plt.gca().xaxis.set_minor_locator(plt.MultipleLocator(200))

plt.xlabel('mRNA UMI')
plt.ylabel('mScarlet')
plt.title('Distribution of mScarlet vs total mRNA UMI')

plt.show()

# Apply both stricter cutoffs in either order; each yields the same cells as pos_final.
pos_1 = pos_gene_760[pos_gene_760['Total UMIs'] < 3400]
pos_2 = pos_200_34k[pos_200_34k['gene_count'] < 760]

mS_count = (pos_1['mScarlet'] > 0).sum()
mS_count

mS_count = (pos_2['mScarlet'] > 0).sum()
mS_count

# Final filter: 200 < total UMIs < 3400 and fewer than 760 detected genes.
pos_final = pos_200_7k[(pos_200_7k['Total UMIs'] < 3400) & (pos_200_7k['gene_count'] < 760)]

mS_count = (pos_final['mScarlet'] > 0).sum()
mS_count

# Save the unnormalized counts of the filtered cells. The two derived per-cell
# columns are dropped by name rather than by relying on their position at the end.
pos_final.drop(columns=['Total UMIs', 'gene_count']).to_csv('pos_760_34k_beforenorm.csv')

"""# R3 sctransform analysis with Cre 


"""

# sctransform output produced outside this script; the 'Unnamed: 0' column holds gene names.
pos_sctscaleddata = pd.read_feather('data_sctransformed.feather')
pos_sctscaleddata.set_index('Unnamed: 0', inplace=True)
pos_sctscaleddata.index.name = 'gene'

pos_sctscaleddata

# Transpose to cells x genes, the obs x var orientation AnnData expects.
pos_100_sctscaleddata = pos_sctscaleddata.T
pos_100_sctscaleddata

# Write the transposed sctransform values to CSV so scanpy can load them as an
# AnnData object (first column = cell barcodes -> obs_names).
pos_100_sctscaleddata.to_csv('exp1_wcre_sct.csv')

adata = sc.read_csv('exp1_wcre_sct.csv', first_column_names=True)

adata

"""Preparing the plotting of corrected UMI counts of the most variable genes by saving these counts as a new layer to the same object."""

pearson_genes = list(adata.var_names)

# sctransform corrected UMI counts, stored genes x cells on disk.
corrumi = pd.read_csv('sct_corrected_counts.csv', index_col=0)

# Cells x genes, matching adata's orientation.
corrumi_scanpy = corrumi.T

corrumi_scanpy

# Select rows by adata's cell barcodes and columns by its genes, so the layer is
# aligned element-wise with adata.X rather than merely equal in shape.
# .loc raises KeyError if any barcode or gene is missing from the corrected counts.
corrumi_filtered = corrumi_scanpy.loc[adata.obs_names, pearson_genes]

# A plain array is what AnnData.layers stores here.
corrumi_layer = corrumi_filtered.to_numpy()

# Explicit check (not `assert`, which is stripped under -O); catches duplicate labels.
if corrumi_layer.shape != adata.shape:
    raise ValueError("Shape mismatch between adata and filtered UMI count matrix")

adata.layers["corrected_umi"] = corrumi_layer

"""End of corrected UMI count integration into the same object."""

# Dimensionality reduction on the sctransform Pearson residuals held in adata.X.
sc.tl.pca(adata, n_comps=50, svd_solver='arpack')

adata

sc.pl.pca(adata, size=50)

# Fraction of total variance captured by the computed PCs.
sum(adata.uns['pca']['variance_ratio'])

# Variance ratio of each PC on a log scale.
sc.pl.pca_variance_ratio(adata, n_pcs=50, log=True)

# Genes with the largest loadings on PCs 1-6.
sc.pl.pca_loadings(adata, components=list(range(1, 7)))

# UMAP is built on the neighbourhood graph, so that is computed first.
sc.pp.neighbors(adata)
sc.tl.umap(adata)
sc.pl.umap(adata)

# Leiden clustering at resolution 0.15 using the leidenalg backend.
sc.tl.leiden(adata, resolution=0.15, flavor='leidenalg')

# UMAP coloured by Leiden cluster. show=False returns the plot object instead of
# displaying it, so the figure can be saved below.
umap_plot = sc.pl.umap(
    adata,
    color='leiden',
    show=False,
)

# SVG is a vector format, so no DPI is needed. For a raster copy use e.g.
# umap_plot.figure.savefig('umap_high_res.png', dpi=300).
umap_plot.figure.savefig('umap_high_res.svg')

# Marker genes for each Leiden cluster (resolution 0.15), Wilcoxon rank-sum test;
# the top 20 per cluster are plotted and saved as SVG.
sc.tl.rank_genes_groups(adata, 'leiden', method='wilcoxon')
sc.pl.rank_genes_groups(
    adata,
    n_genes=20,
    sharey=False,
    fontsize=12,
    save='marker_genes_0.15_high_res.svg',
)

# Number of cells per Leiden cluster.
cluster_sizes = adata.obs['leiden'].value_counts()
print(cluster_sizes)

# Per-cell total of the corrected UMI layer, stored in obs for plotting.
adata.obs['total_mRNA'] = adata.layers['corrected_umi'].sum(axis=1)

# Plot options for the per-cluster violin plot.
violin_kwargs = dict(
    groupby='leiden',
    jitter=True,      # overlay individual cells as jittered points
    size=1,           # size of the jittered points
    scale='width',    # violins scaled by width
    linewidth=1.5,    # outline thickness of the violins
    show=False,       # return the plot object instead of displaying it
)
violinplot = sc.pl.violin(adata, ['total_mRNA'], **violin_kwargs)

# Save in vector format for publication.
violinplot.figure.savefig('violin_plot_high_res.svg')

import pandas as pd
from scipy.stats import ranksums
from statsmodels.stats.multitest import multipletests

# Leiden cluster whose total mRNA is compared against every other cluster.
reference_cluster = '3'

# .copy() so converting the column below does not raise SettingWithCopyWarning.
total_expression = adata.obs[['total_mRNA', 'leiden']].copy()
total_expression['leiden'] = total_expression['leiden'].astype('category')

ref_expr = total_expression.loc[total_expression['leiden'] == reference_cluster, 'total_mRNA']
if ref_expr.empty:
    raise ValueError(f"Reference cluster {reference_cluster!r} has no cells")

# Wilcoxon rank-sum test of the reference cluster against each other cluster.
results = []
for cluster in total_expression['leiden'].cat.categories:
    if cluster == reference_cluster:
        continue
    cluster_expr = total_expression.loc[total_expression['leiden'] == cluster, 'total_mRNA']
    stat, p_value = ranksums(ref_expr, cluster_expr)
    results.append({'Cluster': cluster, 'Z-score': stat, 'p-value': p_value})

results_df = pd.DataFrame(results, columns=['Cluster', 'Z-score', 'p-value'])

# Benjamini-Hochberg correction across the comparisons (skipped if there are none).
if not results_df.empty:
    results_df['Adjusted p-value'] = multipletests(results_df['p-value'], method='fdr_bh')[1]

print(results_df)

# Reporter / Cre transcripts summarised per Leiden cluster.
genes = ['mScarlet', 'msGFP2', 'Cre']

# Corrected UMI counts for those genes (cells x genes). The layer is assigned from
# DataFrame.to_numpy(), a dense ndarray with no .toarray(); densify only if it is sparse.
layer = adata[:, genes].layers['corrected_umi']
gene_expression = layer.toarray() if hasattr(layer, 'toarray') else np.asarray(layer)

# A cell counts as positive for a gene when its corrected count is non-zero.
posexp = gene_expression > 0

clusters = adata.obs['leiden']

pos_cols = [f'pos_{g}' for g in genes]
expr_cols = [f'expression_gene{i + 1}' for i in range(len(genes))]

# One row per cell: cluster label plus positivity flags / expression per gene.
df = pd.DataFrame({'cluster': clusters, **{c: posexp[:, i] for i, c in enumerate(pos_cols)}})
df2 = pd.DataFrame({'cluster': clusters, **{c: gene_expression[:, i] for i, c in enumerate(expr_cols)}})

# observed=False keeps every cluster category, including empty ones.
pos_by_cluster = df.groupby('cluster', observed=False)[pos_cols]

# Fraction of positive cells per cluster and gene.
frac_pos = pos_by_cluster.mean().rename(columns=dict(zip(pos_cols, genes)))
print(frac_pos)

# Number of positive cells per cluster and gene.
count_pos = pos_by_cluster.sum().rename(columns=dict(zip(pos_cols, genes)))
print(count_pos)

# Mean corrected expression per cluster and gene, zeros included.
mean_expression = (
    df2.groupby('cluster', observed=False)[expr_cols]
    .mean()
    .rename(columns=dict(zip(expr_cols, genes)))
)
print(mean_expression)

# Curated gene sets: each column of Genes_of_interest.csv is one set.
goi = pd.read_csv('Genes_of_interest.csv')


def _gene_set(column):
    """Return the non-missing entries of ``goi[column]`` as a list."""
    return goi[column].dropna().tolist()


dna_repair = _gene_set('DNA repair')
pp = _gene_set('ppGpp synthesis')
psp_reg = _gene_set('Psp regulon')
ta = _gene_set('TA')
trans = _gene_set('Proteostasis')
stress = _gene_set('Stress response')
metabolism = _gene_set('Metabolism')
efflux = _gene_set('Drug efflux')

dna_synth = _gene_set('DNA synthesis')
phoPQon = _gene_set('phoPQ_on')
phoPQoff = _gene_set('phoPQ_off')
gadE_reg = _gene_set('gadE')
gadX_reg = _gene_set('gadX')
arcABon = _gene_set('arcAB_on')
arcABoff = _gene_set('arcAB_off')
ribo30s = _gene_set('30S')
ribo50s = _gene_set('50S')
dksAon = _gene_set('dksA_on')
dksAoff = _gene_set('dksA_off')

# Ribosomal gene names with only their first letter lower-cased.
ribo30 = [g[0].lower() + g[1:] for g in ribo30s]
ribo50 = [g[0].lower() + g[1:] for g in ribo50s]

# Reporter / Cre feature names.
reporter = ['mScarlet', 'PT5lac', 'msGFP2', 'Cre']

# ribo50_vis / ribo30_vis / psp_vis were referenced but never defined (NameError).
# They are defined here as the gene sets restricted to genes present in adata,
# because sc.pl.dotplot rejects names missing from adata.var_names.
# NOTE(review): confirm this matches the intended "visible" gene selection.
available_genes = set(adata.var_names)
ribo50_vis = [g for g in ribo50 if g in available_genes]
ribo30_vis = [g for g in ribo30 if g in available_genes]
psp_vis = [g for g in psp_reg if g in available_genes]

# Options shared by every dot plot: per-gene standardised means of the
# corrected UMI layer, averaged over expressing cells only.
dotplot_kwargs = dict(
    groupby='leiden',
    standard_scale='var',
    layer='corrected_umi',
    mean_only_expressed=True,
)

sc.pl.dotplot(adata, ribo50_vis, save='ribo50_high_res.svg', **dotplot_kwargs)

sc.pl.dotplot(adata, ribo30_vis, save='ribo30_high_res.svg', **dotplot_kwargs)

pspplot = sc.pl.dotplot(adata, psp_vis, save='dotplot_high_res.svg', **dotplot_kwargs)

sc.pl.dotplot(adata, ['arcA', 'arcB'], save='arc_high_res.svg', **dotplot_kwargs)

# Hand-picked genes, plotted together and as two halves.
selected_pt1 = ['rplA', 'rplL', 'rplK', 'rplN', 'rpsE', 'rplR', 'rplV', 'rpsA', 'rpsB', 'rpsL',
                'relA', 'spoT', 'hipA']
selected_pt2 = ['recA', 'recN', 'sulA', 'aceE', 'aceF', 'lpd', 'acnB', 'fadE', 'gadA', 'gadB',
                'gadC', 'gadE', 'gadX', 'hns', 'uspE']

sc.pl.dotplot(adata, selected_pt1 + selected_pt2, save='selected_high_res.svg', **dotplot_kwargs)

sc.pl.dotplot(adata, selected_pt1, save='selected_high_res_pt1.svg', **dotplot_kwargs)

sc.pl.dotplot(adata, selected_pt2, save='selected_high_res_pt2.svg', **dotplot_kwargs)